Conversation

@dlwh
Member

@dlwh dlwh commented Dec 5, 2025

This PR introduces Grugformer: a “grug-simple” JAX LM implementation that leans into explicit sharding and top-level functions rather than heavy abstractions. It adds a minimal core (levanter.grug) plus a small adapter (levanter.models.grug_wrapper) so it can run through the existing Levanter trainer pipeline, and it includes speedrun entrypoints + tests that lock down the intended “grug core surface”.
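
For flavor, here is a hypothetical miniature of that style (illustrative only, not code from this PR): plain pytree parameters, top-level pure functions, no framework classes.

from typing import NamedTuple

import jax
import jax.numpy as jnp


class MlpParams(NamedTuple):  # NamedTuples are pytrees out of the box
    w1: jax.Array
    w2: jax.Array


def init_mlp(key: jax.Array, d_in: int, d_hidden: int) -> MlpParams:
    k1, k2 = jax.random.split(key)
    return MlpParams(
        w1=jax.random.normal(k1, (d_in, d_hidden)) * d_in**-0.5,
        w2=jax.random.normal(k2, (d_hidden, d_in)) * d_hidden**-0.5,
    )


@jax.jit
def mlp(params: MlpParams, x: jax.Array) -> jax.Array:
    return jax.nn.gelu(x @ params.w1) @ params.w2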

What’s Included

New Grug core (minimal, notebook-like)

  • New package: lib/levanter/src/levanter/grug/
    • attention.py: Grug-local AttentionMask spec + attention implementation (TPU Splash when on TPU; reference fallback otherwise; a dispatch sketch follows this list).
    • model.py: parameter dataclasses + init/forward/activations/loss functions.
    • loss.py: blockwise “large vocab friendly” CE path (avoids full logits materialization; see note below on tradeoffs).
    • data.py, main.py: minimal training/data wiring to run in-repo.
  • Exported surface is intentionally small (functions + dataclasses; minimal mutation).
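
As a flavor of the attention.py dispatch mentioned above, a minimal sketch (reference_attention and splash_attention are illustrative names, not the actual module contents):

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, mask=None):
    # q, k, v: [batch, heads, seq, head_dim]; mask: boolean, broadcastable to scores
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

def attention(q, k, v, mask=None):
    if jax.default_backend() == "tpu":
        return splash_attention(q, k, v, mask)  # hypothetical wrapper around the TPU Splash kernel
    return reference_attention(q, k, v, mask)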

Levanter adapter

  • lib/levanter/src/levanter/models/grug_wrapper.py: wraps grug core behind Levanter’s LmConfig/trainer expectations while keeping the core itself free of NamedArray-heavy abstractions.

Speedruns / templates

  • experiments/speedrun/grugformer_starter/grugformer_speedrun.py: a grug speedrun template for quick iteration.
  • experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py: “hackable” grug attention-sink variant (copy/paste edit surface).
  • experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py: head-to-head comparison (Hackable Transformer vs Grugformer, no sinks). Hackable path runs without explicit mesh axes for now.

Tests (lock the “grug core surface”)

  • All Grug tests live under lib/levanter/tests/grug/:
    • test_grugformer_core.py: core API + mesh/sharding sanity.
    • test_grugformer_model_loss.py: loss correctness vs full logits on small shapes; wrapper plumbing.
    • test_grugformer_fused_loss.py: loss-related regression coverage.
    • test_grugformer_compilation.py: lowers/jit-traces model+loss under AbstractMesh (no concrete devices required).
    • test_grugformer.py: higher-level smoke coverage (tiny synthetic step).

Documentation

  • .agents/projects/grugformer.md: principles, intended edit surface, and follow-ups.
  • docs/recipes/change_grug.md: workflow for proposing changes (speedrun edit surface → adopt into canonical grug → archive old experiments).
  • docs/reports/grug-archive.md: lightweight “experiment archive log” placeholder so we have somewhere to record removals/renames as grug evolves.

Notable Design Choices / Current Constraints

  • Attention: TPU path uses Splash attention directly; GPU path uses the reference fallback for now.
  • Loss: large-vocab CE is more painful than we’d like under explicit-sharding; we currently use a blockwise “flash-attention style” transform. The block-size knob is intentionally exposed; we’ve observed meaningful perf sensitivity and will likely revisit this with a better kernel later.
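
A minimal sketch of that blockwise idea (hypothetical, not the levanter.grug.loss implementation; assumes float32 activations and that the vocab size divides evenly by block_size):

import jax
import jax.numpy as jnp

def blockwise_cross_entropy(hidden, w_unembed, labels, block_size=4096):
    """hidden: [T, D] float32, w_unembed: [D, V], labels: [T] int32; V % block_size == 0."""
    T, D = hidden.shape
    num_blocks = w_unembed.shape[1] // block_size

    def step(carry, i):
        m, s, tgt = carry  # running max, running sum of exp, target logit (all [T])
        w_blk = jax.lax.dynamic_slice(w_unembed, (0, i * block_size), (D, block_size))
        logits = hidden @ w_blk  # [T, block_size]: only one vocab block materialized
        new_m = jnp.maximum(m, logits.max(axis=-1))
        s = s * jnp.exp(m - new_m) + jnp.exp(logits - new_m[:, None]).sum(axis=-1)
        # grab the target logit if the label lands in this block
        local_idx = jnp.clip(labels - i * block_size, 0, block_size - 1)
        local = jnp.take_along_axis(logits, local_idx[:, None], axis=-1)[:, 0]
        in_block = (labels >= i * block_size) & (labels < (i + 1) * block_size)
        tgt = jnp.where(in_block, local, tgt)
        return (new_m, s, tgt), None

    init = (jnp.full((T,), -jnp.inf), jnp.zeros((T,)), jnp.zeros((T,)))
    (m, s, tgt), _ = jax.lax.scan(step, init, jnp.arange(num_blocks))
    return jnp.log(s) + m - tgt  # per-token -log p(label)

The block size trades peak memory (one [T, block_size] logits tile live at a time) against the number of passes over the unembedding matrix, which is where the perf sensitivity mentioned above comes from.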

How To Try

  • Run the h2h speedrun:
    • python -m experiments.speedrun.grugformer_vs_hackable_125m.grugformer_vs_hackable_125m
    • Set SR_USE_TPU=1 to use TPU preset.
  • Run tests:
    • uv run pytest lib/levanter/tests/grug -q

Follow-ups

  • Implement a faster large-vocab CE path that’s robust under explicit sharding (avoids the current speed/memory tradeoff).
  • Expand the speedrun “gauntlet” checks and add more minimal “edit points” for experiments.

@github-actions
Contributor

This pull request has been inactive for 23 days and is marked as stale.
If there is no further activity within 7 days, it will be automatically closed.
If you believe this PR should remain open, please add a comment or update the PR.

@github-actions github-actions bot added the stale label Dec 29, 2025
@dlwh
Member Author

dlwh commented Dec 29, 2025

bump

@github-actions github-actions bot removed the stale label Dec 30, 2025
@pc0618
Contributor

pc0618 commented Jan 11, 2026

Pushed fix for TPU Splash attention crashing during init on tracers (falls back when x.sharding is unavailable) + removed unsupported arg from the Grugformer speedrun wrapper. Commit: 0ee618f.

@pc0618
Contributor

pc0618 commented Jan 11, 2026

Follow-up (previous comment had shell quoting issues): fix uses x.sharding when available and falls back to x.aval.sharding for tracers during staging; also stops passing tie_embeddings into GrugModelConfig (it is kept only for param counting). Commit: 0ee618f.
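
A rough sketch of the described fallback (hypothetical helper name, not the actual commit):

def _sharding_of(x):
    # Hypothetical helper illustrating the x.sharding -> x.aval.sharding fallback.
    try:
        return x.sharding        # concrete arrays carry a sharding directly
    except Exception:
        return x.aval.sharding   # tracers during jit staging expose it via the aval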

@pc0618
Contributor

pc0618 commented Jan 12, 2026

Added an inline note + refactor in levanter/grug/model.py:init_parameters to use hierarchical key splitting instead of (3 + 7 * num_layers) “magic number” math (more robust to future parameter additions). Commit: b9756b3. Also left a TODO in-code to add a brief explanation in the PR discussion later.
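
A hypothetical illustration of the hierarchical pattern (init_embed/init_layer/init_out are stand-ins, not the real model.py helpers):

import jax

def init_parameters(key, num_layers):
    # one top-level split; each subtree derives its own keys locally
    k_embed, k_layers, k_out = jax.random.split(key, 3)
    layer_keys = jax.random.split(k_layers, num_layers)
    return {
        "embed": init_embed(k_embed),                   # hypothetical helper
        "layers": [init_layer(k) for k in layer_keys],  # each layer splits its key further for its own params
        "out": init_out(k_out),                         # hypothetical helper
    }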

Contributor

@ravwojdyla ravwojdyla left a comment

This is awesome! I may be a little aggressive with the comments to delete "unused" logic/options or reduce the number of files - this is mostly in the spirit of karpathy-ish code 🙇 There are a couple of logic questions in here. Some nits as well, like the __all__, which I dislike¹.

Footnotes

  1. I prefer protected-ish `_` prefixes, but if marin has a policy on `__all__` I'm happy to adjust.

testpaths = ["tests", "experiments"]

# Make sure we timeout before CI kills us, and don't run TPU or slow tests by default
addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'"
Contributor

Is this intentional?

runtime_dict = {
    "working_dir": current_dir,
    "config": {"setup_timeout_seconds": 1800},
    "excludes": [".git", "tests/", "docs/", "**/*.pack", "lib/levanter/docs"],
Contributor

❓ what is the purpose of this change?

Member Author

so i can run tests with ray_run, which i do

@ravwojdyla
Contributor

FYI when I run the starter speedrun (130M only) in us-central1 on TPU (v5p-8), I get OOM:

Total hbm usage >= 101.99G:
    reserved        263.00M
    program         101.73G
    arguments            0B

I can work around this but I wonder if that was supposed to work?


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run the Grug trainer.")
    parser.add_argument("--cache-dir", type=str, default=None, help="Optional TreeCache directory for real data.")
Contributor

Could we make it simpler to run main.py in isolation without depending on TreeCache or synthetic data? I.e. point it at a directory (object-store compatible) dump of some canonical dataset, e.g. OpenWebText, Fineweb or TinyStories even?

Member Author

i'd rather not spend too much time on this given that the main way we'll be running it is via marin's training harness.

@ravwojdyla ravwojdyla mentioned this pull request Jan 16, 2026
Contributor

@ravwojdyla ravwojdyla left a comment

Some more comments from experiments

Comment on lines 111 to 113
    # Grug core currently always has separate token embed + output projection; keep this knob
    # for param counting / compatibility with other LmConfig-based scripts.
    tie_embeddings: bool = False
Contributor

Wouldn't it be more intuitive to not expose this config and instead hard code logic in total_trainable_params? Otherwise it may seem like this flag does something, when it doesn't?

    num_kv_heads=self.num_kv_heads,
    head_dim=self.head_dim,
    max_seq_len=self.max_seq_len,
    tie_embeddings=self.tie_embeddings,
Contributor

tie_embeddings doesn't exist in GrugModelConfig (see other comment)

    labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
    loss_weight = loss_weight.astype(loss_dtype)

    # NOTE: `block_size=None` corresponds to a single full-vocab block. On the 125M speedrun,
Contributor

#2315 (comment) ptal, I can't reproduce this 🙏

Contributor

Btw setting cross_entropy_block_size to, say, ~32k on v5p-8 OOMs in the 125M experiment.

Member Author

yeah somehow this version of fused cross entropy doesn't actually work super well? don't really understand why

Member Author

i'm gonna replace with a pallas kernel at some point


## Working Agreement: How Grug Evolves

- Canonical “best guess” lives in `lib/levanter/src/levanter/grug/`.
Contributor

grug is a reference implementation - should all the reference implementations live in references or something?

Member Author

sure. let's not do that here though since the reorg doesn't exist yet?

- Minimal surface: plain pytrees + explicit mesh + small config.
- Owns data loading, checkpointing, and evaluation in a way that’s easy to copy/paste into speedrun scripts.

2) **Evolve Levanter/Marin to support grug natively**
Contributor

Potentially dumb question, but why can't we make it so we call grug directly instead of even dealing with levanter? Dataloading can be factored out into its own thing maybe?

Member Author

that is the goal yes

Member Author

i think this is a good checkpoint and a next step is to isolate pieces of levanter that are still useful and make them more grug-aware

    if isinstance(mask, AttentionMask):
        mask = mask.materialize_mask(scores.shape[-2], scores.shape[-1])

    if mask is not None:
Member Author

kinda hate this. should maybe simplify i dunno

DEFAULT_AXIS_MAPPING = {"batch": ("replica_dcn", "replica", "data")}


def make_token_dataset(cache: TreeCache[dict], *, seq_len: int) -> TokenSeqDataset:
Member

Food for thought: it's not clear how a user would do clever data loading tricks here such as @ClassicLarry's Document Alignment. Fine if we decide we want to grugify that part later?

Member

More generally, the data loader seems pretty non-gruggy to me since the user still has to go to levanter to figure out what these return types are and how to use them if they want to make stuff custom here.

Member Author

i'd like to do later yes please

Member Author

gruggifying data should be later

Member

Happy to take a pass at Gruggifying data stuff if you'd like since I am at least somewhat familiar with the types from writing the audio loader

Member Author

i was in the middle of a different branch to clean it up for other purposes, maybe after that?

Member

Whatever works best - just don't want you to feel like you are responsible for all gruggifying if you don't want to be

dlwh added 3 commits January 21, 2026 16:59
# Conflicts:
#	lib/marin/src/marin/rl/weight_transfer/arrow_flight.py
#	uv.lock
import jax.numpy as jnp
from einops import rearrange
from jax import random
from jax.sharding import PartitionSpec as P, reshard
Member

Nit: I don't love the alias P here because 1) it's hard to grep for and 2) I don't usually read new code header first, so it's hard to know what this is on first pass.
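
For comparison, the un-aliased form the nit is suggesting (a style sketch, not a change made in this PR):

from jax.sharding import PartitionSpec

spec = PartitionSpec("data", None)  # spelled out: greppable and self-describing at the use site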

Contributor

@ravwojdyla ravwojdyla left a comment

lgtm!
